{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare and Deploy a TensorFlow Model to AI Platform for Online Serving\n",
    "\n",
    "This Notebook demonstrates how to prepare a TensorFlow 2.x model and deploy it for serving with AI Platform Prediction. This example uses the pretrained [ResNet V2 101](https://tfhub.dev/google/imagenet/resnet_v2_101/classification/4) image classification model from [TensorFlow Hub](https://tfhub.dev/) (TF Hub).\n",
    "\n",
    "The Notebook covers the following steps:\n",
    "1. Downloading and running the ResNet module from TF Hub\n",
    "2. Creating serving signatures for the module\n",
    "3. Exporting the model as a SavedModel\n",
    "4. Deploying the SavedModel to AI Platform Prediction\n",
    "5. Validating the deployed model\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "This Notebook was tested on **AI Platform Notebooks** using the standard TF 2.8 image."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import base64\n",
    "import os\n",
    "import json\n",
    "import requests\n",
    "import time\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tensorflow_hub as hub\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from typing import List, Optional, Text, Tuple"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configure GCP environment settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PROJECT_ID = '[your-google-project-id]' # Set your project Id\n",
    "BUCKET = '[your-bucket-name]' # Set your bucket name Id\n",
    "REGION = '[your-region]'  # Set your region for deploying the model\n",
    "MODEL_NAME = 'resnet_classifier'\n",
    "MODEL_VERSION = 'v1'\n",
    "GCS_MODEL_LOCATION = 'gs://{}/models/{}/{}'.format(BUCKET, MODEL_NAME, MODEL_VERSION)\n",
    "THUB_MODEL_HANDLE = 'https://tfhub.dev/google/imagenet/resnet_v2_101/classification/4'\n",
    "IMAGENET_LABELS_URL = 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'\n",
    "IMAGES_FOLDER = 'test_images'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a local workspace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "LOCAL_WORKSPACE = '/tmp/workspace'\n",
    "if tf.io.gfile.exists(LOCAL_WORKSPACE):\n",
    "    print(\"Removing previous workspace artifacts...\")\n",
    "    tf.io.gfile.rmtree(LOCAL_WORKSPACE)\n",
    "\n",
    "print(\"Creating a new workspace...\")\n",
    "tf.io.gfile.makedirs(LOCAL_WORKSPACE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Loading and Running the ResNet Module"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.1. Download and instantiate the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"TFHUB_DOWNLOAD_PROGRESS\"] = 'True'\n",
    "\n",
    "local_savedmodel_path = hub.resolve(THUB_MODEL_HANDLE)\n",
    "\n",
    "print(local_savedmodel_path)\n",
    "!ls -la {local_savedmodel_path}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = hub.load(THUB_MODEL_HANDLE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The expected input to most TF Hub TF2 image classification models, including ResNet 101, is a rank 4 tensor conforming to the following tensor specification: `tf.TensorSpec([None, height, width, 3], tf.float32)`. For the ResNet 101 model, the expected image size is `height x width = 224 x 224`. The color values for all channels are expected to be normalized to the [0, 1] range. \n",
    "\n",
    "The output of the model is a batch of logits vectors. The indices into the logits are the `num_classes = 1001` classes  from the ImageNet dataset. The mapping from indices to class labels can be found in the [labels file](download.tensorflow.org/data/ImageNetLabels.txt) with class 0 for \"background\", followed by 1000 actual ImageNet classes.\n",
    "    "
   ]
  },
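  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the input contract described above, the following cell builds a dummy batch with NumPy and checks its rank, image size, channel count, dtype, and value range. The helper name `check_resnet_input` and the random dummy data are illustrative assumptions, not part of TF Hub or of the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def check_resnet_input(batch, size=224):\n",
    "    # Hypothetical helper: verifies that a batch matches\n",
    "    # tf.TensorSpec([None, size, size, 3], tf.float32) with values in [0, 1].\n",
    "    assert batch.ndim == 4, 'expected a rank 4 tensor'\n",
    "    assert batch.shape[1:] == (size, size, 3), 'expected {0}x{0} RGB images'.format(size)\n",
    "    assert batch.dtype == np.float32, 'expected float32 values'\n",
    "    assert batch.min() >= 0.0 and batch.max() <= 1.0, 'expected values in [0, 1]'\n",
    "    return True\n",
    "\n",
    "# Dummy batch of two random uint8 images, scaled to the [0, 1] range.\n",
    "dummy_uint8 = np.random.randint(0, 256, size=(2, 224, 224, 3), dtype=np.uint8)\n",
    "dummy_batch = dummy_uint8.astype(np.float32) / 255.0\n",
    "check_resnet_input(dummy_batch)"
   ]
  },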
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "We will now test the model on a couple of JPEG images. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2. Display sample images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_list = [tf.io.read_file(os.path.join(IMAGES_FOLDER, image_path))\n",
    "         for image_path in os.listdir(IMAGES_FOLDER)]\n",
    "\n",
    "ncolumns = len(image_list) if len(image_list) < 4 else 4\n",
    "nrows = int(len(image_list) // ncolumns)\n",
    "fig, axes = plt.subplots(nrows=nrows, ncols=ncolumns, figsize=(10,10))\n",
    "for axis, image in zip(axes.flat[0:], image_list):\n",
    "    decoded_image = tf.image.decode_image(image)\n",
    "    axis.set_title(decoded_image.shape)\n",
    "    axis.imshow(decoded_image.numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.3. Preprocess the testing images\n",
    "\n",
    "The images need to be preprocessed to conform to the format expected by the ResNet101 model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _decode_and_scale(image, size):\n",
    "    image = tf.image.decode_image(image, expand_animations=False)\n",
    "        \n",
    "    image_height = image.shape[0]\n",
    "    image_width = image.shape[1]\n",
    "    crop_size = tf.minimum(image_height, image_width)\n",
    "    offset_height = ((image_height - crop_size) + 1) // 2\n",
    "    offset_width = ((image_width - crop_size) + 1) // 2\n",
    "        \n",
    "    image = tf.image.crop_to_bounding_box(image, offset_height, offset_width, crop_size, crop_size)\n",
    "    image = tf.cast(tf.image.resize(image, [size, size]), tf.uint8)\n",
    "    \n",
    "    return image"
   ]
  },
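  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the cropping arithmetic in `_decode_and_scale` concrete, here is a small pure-Python sketch. The helper name `center_crop_box` and the 480 x 640 example size are illustrative assumptions; the offsets it returns are the values that would be passed to `tf.image.crop_to_bounding_box`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def center_crop_box(image_height, image_width):\n",
    "    # Hypothetical helper that mirrors the offset arithmetic in _decode_and_scale.\n",
    "    crop_size = min(image_height, image_width)\n",
    "    offset_height = ((image_height - crop_size) + 1) // 2\n",
    "    offset_width = ((image_width - crop_size) + 1) // 2\n",
    "    return offset_height, offset_width, crop_size\n",
    "\n",
    "# A 480 x 640 (height x width) image is cropped to a centered 480 x 480 square.\n",
    "print(center_crop_box(480, 640))  # (0, 80, 480)"
   ]
  },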
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "size = 224\n",
    "\n",
    "raw_images = tf.stack(image_list)\n",
    "preprocessed_images = tf.map_fn(lambda x: _decode_and_scale(x, size), raw_images, dtype=tf.uint8)\n",
    "preprocessed_images = tf.image.convert_image_dtype(preprocessed_images, tf.float32)\n",
    "print(preprocessed_images.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.4. Run inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions = model(preprocessed_images)\n",
    "predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model returns a batch of arrays with logits. This is not a very user friendly output so we will convert it to the list of ImageNet class labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels_path = tf.keras.utils.get_file(\n",
    "    'ImageNetLabels.txt',\n",
    "    IMAGENET_LABELS_URL)\n",
    "imagenet_labels = np.array(open(labels_path).read().splitlines())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "We will display the 5 highest ranked labels for each image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for prediction in list(predictions):\n",
    "    decoded = imagenet_labels[np.argsort(prediction.numpy())[::-1][:5]]\n",
    "    print(list(decoded))"
   ]
  },
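  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ranking above works directly on logits. If probabilities are also needed, a softmax can be applied first. The following is a minimal NumPy sketch on toy data; the helper name `top_k_probabilities` and the toy logits are illustrative assumptions, not part of the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def top_k_probabilities(logits, k=5):\n",
    "    # Numerically stable softmax over the last axis.\n",
    "    shifted = logits - logits.max(axis=-1, keepdims=True)\n",
    "    exponentials = np.exp(shifted)\n",
    "    probabilities = exponentials / exponentials.sum(axis=-1, keepdims=True)\n",
    "    # Indices of the k largest probabilities per row, highest first.\n",
    "    top_indices = np.argsort(probabilities, axis=-1)[:, ::-1][:, :k]\n",
    "    top_probabilities = np.take_along_axis(probabilities, top_indices, axis=-1)\n",
    "    return top_indices, top_probabilities\n",
    "\n",
    "# Toy batch of two logits vectors over three classes.\n",
    "toy_logits = np.array([[2.0, 0.5, 1.0], [0.1, 3.0, 0.2]])\n",
    "toy_indices, toy_probabilities = top_k_probabilities(toy_logits, k=2)\n",
    "print(toy_indices)\n",
    "print(toy_probabilities)"
   ]
  },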
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Create Serving Signatures"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The inputs and outputs of the model as used during model training may not be optimal for serving. For example, in a typical training pipeline, feature engineering is performed as a separate step preceding model training and hyperparameter tuning. When serving the model, it may be more optimal to embed the feature engineering logic into the serving interface rather than require a client application to preprocess data.\n",
    "\n",
    "\n",
    "The ResNet V2 101 model from TF Hub is optimized for recomposition and fine tuning. Since there are no serving signatures in the model's metadata, it cannot be served with TF Serving as is."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "list(model.signatures)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make it servable, we need to add a serving signature(s) describing the inference method(s) of the model. \n",
    "\n",
    "We will add two signatures: \n",
    "1. **The default signature** - This will expose the default predict method of the ResNet101 model.\n",
    "2. **Prep/post-processing signature** - Since the expected inputs to this interface require a relatively complex image preprocessing to be performed by a client, we will also expose an alternative signature that embeds the preprocessing and postprocessing logic and accepts raw unprocessed images and returns the list of ranked class labels and associated label probabilities. \n",
    "\n",
    "The signatures are created by defining a custom module class derived from the `tf.Module` base class that encapsulates our ResNet model and extends it with a method implementing the image preprocessing and output postprocessing logic. The default method of the custom module is mapped to the default method of the base ResNet module to maintain the analogous interface. \n",
    "\n",
    "The custom module will be exported as `SavedModel` that includes the original model, the preprocessing logic, and two serving signatures.\n",
    "\n",
    "This technique can be generalized to other scenarios where you need to extend a TensorFlow model and you have access to the serialized `SavedModel` but you don't have access to the Python code implementing the model.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2.1. Define the custom serving module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "LABELS_KEY = 'labels'\n",
    "PROBABILITIES_KEY = 'probabilities'\n",
    "NUM_LABELS = 5\n",
    "\n",
    "class ServingModule(tf.Module):\n",
    "    \"\"\"\n",
    "    A custom tf.Module that adds image preprocessing and output post processing to\n",
    "    a base TF 2 image classification model from TF Hub. \n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, base_model, input_size, output_labels):\n",
    "        super(ServingModule, self).__init__()\n",
    "        self._model = base_model\n",
    "        self._input_size = input_size\n",
    "        self._output_labels = tf.constant(output_labels, dtype=tf.string)\n",
    "        \n",
    "\n",
    "    def _decode_and_scale(self, raw_image):\n",
    "        \"\"\"\n",
    "        Decodes, crops, and resizes a single raw image.\n",
    "        \"\"\"\n",
    "        \n",
    "        image = tf.image.decode_image(raw_image, dtype=tf.dtypes.uint8, expand_animations=False)\n",
    "        image_shape = tf.shape(image)\n",
    "        image_height = image_shape[0]\n",
    "        image_width = image_shape[1]\n",
    "        crop_size = tf.minimum(image_height, image_width)\n",
    "        offset_height = ((image_height - crop_size) + 1) // 2\n",
    "        offset_width = ((image_width - crop_size) + 1) // 2\n",
    "        \n",
    "        image = tf.image.crop_to_bounding_box(image, offset_height, offset_width, crop_size, crop_size)\n",
    "        image = tf.image.resize(image, [self._input_size, self._input_size])\n",
    "        image = tf.cast(image, tf.uint8)\n",
    "    \n",
    "        return image\n",
    "    \n",
    "    def _preprocess(self, raw_inputs):\n",
    "        \"\"\"\n",
    "        Preprocesses raw inputs as sent by the client.\n",
    "        \"\"\"\n",
    "        \n",
    "        # A mitigation for https://github.com/tensorflow/tensorflow/issues/28007\n",
    "        with tf.device('/cpu:0'):\n",
    "            images = tf.map_fn(self._decode_and_scale, raw_inputs, dtype=tf.uint8)\n",
    "        images = tf.image.convert_image_dtype(images, tf.float32)\n",
    "        \n",
    "        return images\n",
    "        \n",
    "    def _postprocess(self, model_outputs):\n",
    "        \"\"\"\n",
    "        Postprocesses outputs returned by the base model.\n",
    "        \"\"\"\n",
    "        \n",
    "        probabilities = tf.nn.softmax(model_outputs)\n",
    "        indices = tf.argsort(probabilities, axis=1, direction='DESCENDING')\n",
    "        \n",
    "        return {\n",
    "            LABELS_KEY: tf.gather(self._output_labels, indices, axis=-1)[:,:NUM_LABELS],\n",
    "            PROBABILITIES_KEY: tf.sort(probabilities, direction='DESCENDING')[:,:NUM_LABELS]\n",
    "        }\n",
    "        \n",
    "\n",
    "    @tf.function(input_signature=[tf.TensorSpec([None, 224, 224, 3], tf.float32)])\n",
    "    def __call__(self, x):\n",
    "        \"\"\"\n",
    "        A pass-through to the base model.\n",
    "        \"\"\"\n",
    "        \n",
    "        return self._model(x)\n",
    "\n",
    "    @tf.function(input_signature=[tf.TensorSpec([None], tf.string)])\n",
    "    def predict_labels(self, raw_images):\n",
    "        \"\"\"\n",
    "        Preprocesses inputs, calls the base model \n",
    "        and postprocesses outputs from the base model.\n",
    "        \"\"\"\n",
    "        \n",
    "        # Call the preprocessing handler\n",
    "        images = self._preprocess(raw_images)\n",
    "        \n",
    "        # Call the base model\n",
    "        logits = self._model(images)\n",
    "        \n",
    "        # Call the postprocessing handler\n",
    "        outputs = self._postprocess(logits)\n",
    "        \n",
    "        return outputs\n",
    "        \n",
    "    \n",
    "serving_module = ServingModule(model, 224, imagenet_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2.2. Test the custom serving module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions = serving_module.predict_labels(raw_images)\n",
    "predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Save the custom serving module as `SavedModel`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = os.path.join(LOCAL_WORKSPACE, MODEL_NAME, MODEL_VERSION)\n",
    "\n",
    "default_signature = serving_module.__call__.get_concrete_function()\n",
    "preprocess_signature = serving_module.predict_labels.get_concrete_function()\n",
    "signatures = {\n",
    "    'serving_default': default_signature,\n",
    "    'serving_preprocess': preprocess_signature\n",
    "}\n",
    "\n",
    "tf.saved_model.save(serving_module, model_path, signatures=signatures)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.1. Inspect the `SavedModel`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!saved_model_cli show --dir {model_path} --tag_set serve --all"
   ]
  },
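  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same information can also be read from Python. The following sketch assumes that `model_path` points to the `SavedModel` exported above; it loads the model with `tf.saved_model.load` and prints the inputs and outputs of each serving signature. The variable name `inspected_module` is an illustrative choice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# Assumes model_path is the directory written by tf.saved_model.save above.\n",
    "inspected_module = tf.saved_model.load(model_path)\n",
    "\n",
    "# Each entry maps a signature name to a concrete function.\n",
    "for signature_name, concrete_function in inspected_module.signatures.items():\n",
    "    print(signature_name)\n",
    "    print('  inputs: ', concrete_function.structured_input_signature)\n",
    "    print('  outputs:', concrete_function.structured_outputs)"
   ]
  },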
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2. Test loading and executing the `SavedModel`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = tf.keras.models.load_model(model_path)\n",
    "model.predict_labels(raw_images)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Deploy the `SavedModel` to AI Platform Prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.1. Copy the `SavedModel` to GCS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!gsutil cp -r {model_path} {GCS_MODEL_LOCATION}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!gsutil ls {GCS_MODEL_LOCATION}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2 Create a model in AI Platform Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!gcloud ai-platform models create {MODEL_NAME} \\\n",
    "  --project {PROJECT_ID} \\\n",
    "  --region {REGION}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!gcloud ai-platform models list \\\n",
    "  --project {PROJECT_ID} \\\n",
    "  --region {REGION}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.3 Create a model version "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MACHINE_TYPE='n1-standard-8'\n",
    "ACCELERATOR='count=1,type=nvidia-tesla-p4'\n",
    "\n",
    "!gcloud beta ai-platform versions create {MODEL_VERSION} \\\n",
    "  --model={MODEL_NAME} \\\n",
    "  --origin={GCS_MODEL_LOCATION} \\\n",
    "  --runtime-version=2.8 \\\n",
    "  --framework=TENSORFLOW \\\n",
    "  --python-version=3.7 \\\n",
    "  --machine-type={MACHINE_TYPE} \\\n",
    "  --accelerator={ACCELERATOR} \\\n",
    "  --project={PROJECT_ID} \\\n",
    "  --region={REGION}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!gcloud ai-platform versions list \\\n",
    "  --model={MODEL_NAME} --project={PROJECT_ID} --region={REGION}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Validate the Deployed Model Version to AI Platform Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import googleapiclient.discovery\n",
    "from google.api_core.client_options import ClientOptions\n",
    "\n",
    "prefix = '{}-ml'.format(REGION) if REGION else 'ml'\n",
    "api_endpoint = 'https://{}.googleapis.com'.format(prefix)\n",
    "client_options = ClientOptions(api_endpoint=api_endpoint)\n",
    "service = googleapiclient.discovery.build('ml', 'v1',\n",
    "                                          cache_discovery=False,\n",
    "                                          client_options=client_options)\n",
    "name = 'projects/{}/models/{}/versions/{}'.format(PROJECT_ID, MODEL_NAME, MODEL_VERSION)\n",
    "print(\"Service name: {}\".format(name))\n",
    "\n",
    "def caip_predict(instances, signature_name='serving_default'):\n",
    "    \n",
    "    request_body={\n",
    "        'signature_name': signature_name,\n",
    "        'instances': instances}\n",
    "    \n",
    "    response = service.projects().predict(\n",
    "        name=name,\n",
    "        body=request_body\n",
    "    \n",
    "    ).execute()\n",
    "\n",
    "    if 'error' in response:\n",
    "        raise RuntimeError(response['error'])\n",
    "\n",
    "    outputs = response['predictions']\n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "signature_name = 'serving_preprocess'\n",
    "\n",
    "encoded_images = [{'b64': base64.b64encode(image.numpy()).decode('utf-8')} \n",
    "                  for image in image_list]  \n",
    "\n",
    "caip_predict(encoded_images, signature_name=signature_name)"
   ]
  },
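  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To show what the request built above looks like on the wire, here is a minimal standard-library sketch. It wraps placeholder bytes in the same `{'b64': ...}` structure, serializes the request body to JSON, and checks that decoding recovers the original bytes. The placeholder bytes and variable names are illustrative assumptions; no call to the service is made."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import base64\n",
    "import json\n",
    "\n",
    "# Placeholder bytes stand in for the content of a real JPEG file.\n",
    "placeholder_bytes = b'raw image bytes go here'\n",
    "\n",
    "# Binary values are base64-encoded and wrapped in an object with a 'b64' key.\n",
    "instance = {'b64': base64.b64encode(placeholder_bytes).decode('utf-8')}\n",
    "example_body = {'signature_name': 'serving_preprocess', 'instances': [instance]}\n",
    "\n",
    "serialized_body = json.dumps(example_body)\n",
    "print('Request size in bytes:', len(serialized_body.encode('utf-8')))\n",
    "\n",
    "# Decoding the 'b64' field recovers the original bytes.\n",
    "decoded_bytes = base64.b64decode(json.loads(serialized_body)['instances'][0]['b64'])\n",
    "assert decoded_bytes == placeholder_bytes"
   ]
  },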
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## License\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "you may not use this file except in compliance with the License.\n",
    "You may obtain a copy of the License at [https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the License for the specific language governing permissions and limitations under the License."
   ]
  }
 ],
 "metadata": {
  "environment": {
   "name": "tf2-2-2-gpu.2-2.m50",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/tf2-2-2-gpu.2-2:m50"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
